from __future__ import annotations

import os
import warnings
import typing_extensions
import typing
import functools

from baml_py.logging import (
    get_log_level as baml_get_log_level,
    set_log_level as baml_set_log_level,
)
from .globals import reset_baml_env_vars

# Generic placeholders used by ``_deprecated`` so that a decorated function is
# typed as ``Callable[pT, rT] -> Callable[pT, rT]``, i.e. type checkers still
# see the original parameters and return type after decoration.
rT = typing_extensions.TypeVar("rT")  # return type
pT = typing_extensions.ParamSpec("pT")  # parameters type


def _deprecated(message: str):
    def decorator(func: typing.Callable[pT, rT]) -> typing.Callable[pT, rT]:
        """Use this decorator to mark functions as deprecated.
        Every time the decorated function runs, it will emit
        a "deprecation" warning."""

        @functools.wraps(func)
        def new_func(*args: pT.args, **kwargs: pT.kwargs):
            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
            warnings.warn(
                "Call to a deprecated function {}.".format(func.__name__) + message,
                category=DeprecationWarning,
                stacklevel=2,
            )
            warnings.simplefilter("default", DeprecationWarning)  # reset filter
            return func(*args, **kwargs)

        return new_func

    return decorator


@_deprecated("Use os.environ['BAML_LOG'] instead")
def get_log_level():
    """
    Return the BAML Python client's current log level.

    Deprecated: each call emits a DeprecationWarning that points callers at
    the ``BAML_LOG`` environment variable. The returned value is whatever
    ``baml_py.logging.get_log_level`` reports.
    """
    current_level = baml_get_log_level()
    return current_level


@_deprecated("Use os.environ['BAML_LOG'] instead")
def set_log_level(
    level: typing_extensions.Literal["DEBUG", "INFO", "WARN", "ERROR", "OFF"] | str,
):
    """
    Change the BAML Python client's log level.

    Deprecated: emits a DeprecationWarning on each call. The level is first
    handed to ``baml_py.logging.set_log_level``. Only after that call returns
    is the same value written to the ``BAML_LOG`` environment variable, so a
    rejected level leaves the environment untouched.
    """
    env_var_name = "BAML_LOG"
    baml_set_log_level(level)
    os.environ[env_var_name] = level


@_deprecated("Use os.environ['BAML_LOG_JSON_MODE'] instead")
def set_log_json_mode():
    """
    Turn on JSON-formatted logging for the BAML Python client.

    Deprecated: emits a DeprecationWarning on each call. It works by writing
    the string ``"true"`` to the ``BAML_LOG_JSON_MODE`` environment variable.
    """
    os.environ.update({"BAML_LOG_JSON_MODE": "true"})


@_deprecated("Use os.environ['BAML_LOG_MAX_CHUNK_LENGTH'] instead")
def set_log_max_chunk_length(length: int = 1000):
    """
    Set the maximum log chunk length for the BAML Python client.

    Deprecated: emits a DeprecationWarning on each call. The value is stored
    as a decimal string in the ``BAML_LOG_MAX_CHUNK_LENGTH`` environment
    variable.

    Args:
        length: Maximum chunk length to configure. Defaults to ``1000``, the
            value this function always wrote when it accepted no arguments,
            so existing no-argument calls behave exactly as before.

    Raises:
        TypeError: If ``length`` is not an ``int`` (``bool`` is rejected too,
            since ``True``/``False`` would otherwise be written as text).
    """
    # Validate before touching the environment so a bad value never leaks
    # into BAML_LOG_MAX_CHUNK_LENGTH.
    if isinstance(length, bool) or not isinstance(length, int):
        raise TypeError(f"length must be an int, got {type(length).__name__}")
    os.environ["BAML_LOG_MAX_CHUNK_LENGTH"] = str(length)


def set_log_max_message_length(*args, **kwargs):
    """
    Compatibility alias for ``set_log_max_chunk_length``.

    This name exists because the documentation refers to it. Every positional
    and keyword argument is forwarded as-is, and the target's return value is
    passed back. The deprecation warning is raised by the target, so its text
    names ``set_log_max_chunk_length`` rather than this alias.
    """
    target = set_log_max_chunk_length
    return target(*args, **kwargs)


# Public API of this module: the deprecated logging helpers defined above,
# plus ``reset_baml_env_vars`` re-exported from ``.globals``.
__all__ = [
    "set_log_level",
    "get_log_level",
    "set_log_json_mode",
    "reset_baml_env_vars",
    "set_log_max_message_length",
    "set_log_max_chunk_length",
]
